import os
import gzip
import time

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns

from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, TransformerMixin

from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram

import pandas as pd

import gseapy as gp
from gseapy import barplot, dotplot, Biomart

# Marker gene panel; these names are used further down to select the columns
# of the normalised expression matrix that are drawn in the heatmap.
mkr = [
    'HOMER3', 'CD14', 'IFI27', 'ZEB2', 'IL18R1', 'DAAM2',
    'GBP5', 'NKG7', 'PATL2', 'DGKH', 'SLAMF7', 'STAT1',
]

# Public data (GSE120649)
# Collect the per-sample header rows of the GEO series matrix whose first
# tab-separated field is one of the tags below.
keys_tags = ['!Sample_title', '!Sample_geo_accession', '!Sample_characteristics_ch1']
aux = []
with gzip.open('GSE120649_series_matrix.txt.gz') as gse_f:
    for raw_line in gse_f:
        # The archive is opened in binary mode, so decode each line exactly once.
        text = raw_line.decode()
        if text.split('\t')[0] not in keys_tags:
            continue
        aux.append(text.replace('\n','').replace('"','').split('\t'))

# One row per sample after the transpose; only the first two retained header
# rows are kept (assumes the title row precedes the accession row in the file).
df_GSE120649_meta = pd.DataFrame(aux).set_index(0).T.iloc[:, :2]
df_GSE120649_meta = df_GSE120649_meta.set_axis(['Title', 'Sample Name'], axis=1)

# Rejection flags are derived purely from phrases in the sample title.
sample_titles = df_GSE120649_meta['Title']
has_abmr = sample_titles.str.contains('with antibody mediated rejection')
has_tcmr = sample_titles.str.contains('with T cell mediated rejection')
df_GSE120649_meta['abmrh'] = np.where(has_abmr, 'ABMR', '')
df_GSE120649_meta['tcmr'] = np.where(has_tcmr, 'TCMR', '')

# Concatenate both flags; samples carrying neither flag are labelled 'Non_ar'.
combined_flags = df_GSE120649_meta['abmrh'] + df_GSE120649_meta['tcmr']
df_GSE120649_meta['ar'] = np.where(combined_flags == '', 'Non_ar', combined_flags)

# Attach SRA run accessions via the shared 'Sample Name' column and index by run.
run_lookup = pd.read_csv('SraRunTable.txt')[['Run','Sample Name']]
df_GSE120649_meta = df_GSE120649_meta.merge(run_lookup, on='Sample Name')
df_GSE120649_meta = df_GSE120649_meta.set_index('Run')[['abmrh','tcmr','ar']]

df_GSE120649_meta['gse'] = 'GSE120649'

# Transcript-level annotation table. Its 'Transcript_ID' and 'Gene_name'
# columns are used below to collapse salmon transcript quantifications to genes.
tx2gene = pd.read_csv('gene_id2gene_nameWgene_lenght_biotype_gencode.v44.annotation.csv')

# Xeno metadata
# The file is read without a header row, so columns are addressed by integer position.
# NOTE(review): later code selects an 'ar' column from this frame, which a
# header-less read does not create -- confirm where the columns get named.
df_xeno_meta = pd.read_csv('/mnt/d/xeno/xeno_bulk_metadata.csv', header=None)
# Drop the '_R1_001' suffix (presumably a FASTQ file-name suffix -- confirm) from
# column 3, which is later used as the sample index.
df_xeno_meta[3] = df_xeno_meta[3].str.replace('_R1_001','')

#Expression data
# Each entry is a directory holding one sub-directory per sample, each with a
# salmon 'quant.sf' file.
salmon_paths = ['posttransplant_PRJNA493832_GSE120649/trimmed_salmon_quant/',
                'xeno_bulk/trimmed_salmon_quant/']
tpm = []
counts_tpm = []
vsd = []

# Only these two annotation columns are needed for the transcript -> gene join;
# slice them once instead of once per sample.
tx_gene_map = tx2gene[['Transcript_ID','Gene_name']]

for path in salmon_paths:
    tpm_aux = []
    counts_aux = []
    # Keep sample directories only (a stray file would break the quant.sf read)
    # and sort them so the row order is reproducible across file systems.
    srr_list = sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )
    if not srr_list:
        raise FileNotFoundError(f'no sample directories found in {path!r}')

    # Progress label: name of the dataset directory that contains the quant
    # directory. The previous fixed index `path.split('/')[3]` raised IndexError
    # for the relative paths listed above.
    dataset_label = os.path.basename(os.path.dirname(os.path.normpath(path))) or path
    print(dataset_label, end=': ')

    for j,srr in enumerate(srr_list):
        quant = pd.read_csv(os.path.join(path,srr,'quant.sf'), sep='\t')
        # Map transcripts to genes and sum every numeric column per gene.
        df_aux = quant[['Name','TPM','EffectiveLength','NumReads']].merge(
            tx_gene_map, left_on='Name', right_on='Transcript_ID')
        df_aux = df_aux.groupby('Gene_name').sum(numeric_only=True)
        # Integer counts: gene-level NumReads rounded to the nearest integer.
        df_aux['counts_tpm'] = np.round(df_aux.NumReads,0).astype(int)
        tpm_aux.append(df_aux.TPM.values)
        counts_aux.append(df_aux.counts_tpm.values)
        print(j+1, end=' ')

    # Column labels are taken from the last sample processed; this assumes every
    # sample yields the same gene index (same transcriptome and annotation).
    tpm.append(pd.DataFrame(tpm_aux, columns=df_aux.index, index=srr_list))
    counts_tpm.append(pd.DataFrame(counts_aux, columns=df_aux.index, index=srr_list))

    # Restrict to genes with at least one read in this dataset.
    dataset_counts = counts_tpm[-1]
    expressed = dataset_counts.sum(axis=0) > 0
    x = dataset_counts.loc[:, expressed].to_numpy(dtype='float64', copy=True)

    # x becomes log(count / mean count of the sample); zero counts give -inf
    # and are turned into NaN so that they are ignored by nanmedian.
    with np.errstate(invalid='ignore', divide='ignore'):
        x /= x.mean(axis=1)[:, np.newaxis]
        x = np.log(x)
    x[~np.isfinite(x)] = np.nan
    # Per-gene median of the log ratios, centred on its mean and exponentiated.
    nf = np.nanmedian(x, axis=0)
    nf -= nf.mean()
    nf = np.exp(nf)

    # NOTE(review): the stored frame is the log-ratio matrix divided by the
    # per-gene factor. Despite the name this does not look like a
    # variance-stabilising transform, and the axes look transposed relative to
    # a median-of-ratios size-factor computation (one factor per sample).
    # Arithmetic is left unchanged -- confirm the intent. df_vsd is not used in
    # the visible part of this script.
    vsd.append(pd.DataFrame(x / nf,
                            columns=dataset_counts.columns[expressed],
                            index=dataset_counts.index))

    print('\n')

# Stack both datasets row-wise (samples as rows, genes as columns).
df_tpm = pd.concat(tpm, axis=0)
df_counts = pd.concat(counts_tpm, axis=0)
df_vsd = pd.concat(vsd, axis=0)

# Combine the rejection labels of both cohorts into one frame indexed by sample.
# NOTE(review): df_xeno_meta is read with header=None, so an 'ar' column only
# exists if the columns are named somewhere not visible here -- confirm.
xeno_ar = df_xeno_meta.set_index(3)[['ar']]
df_meta = pd.concat([df_GSE120649_meta[['ar']], xeno_ar])

# Substring replacements, applied in this exact order, that turn the raw
# labels into the display labels.
_ar_relabel = (
    ('ABMR', 'Allorejection'),
    ('TCMR', 'Allorejection'),
    ('Non_ar', 'Allo-Non-rejection'),
    ('D7', 'Xenorejection (D7)'),
    ('D1', 'Xeno-Non-rejection (D1)'),
    ('D26', 'Xeno-Non-rejection (D26)'),
)
for _old, _new in _ar_relabel:
    df_meta['ar'] = df_meta.ar.str.replace(_old, _new)

# Coarse diagnosis: 'orejection' matches the Allo-/Xenorejection labels,
# 'n-rejection' matches the *-Non-rejection labels; anything else is 'Pre'.
cond = [df_meta.ar.str.contains('orejection'), df_meta.ar.str.contains('n-rejection')]
choice = ['Rejection','Non-rejection']
df_meta['Diagnosis_'] = np.select(condlist=cond, choicelist=choice, default='Pre')

# Same frame plus a 'Diagnosis' column holding the colour for each diagnosis.
_diagnosis_palette = {'Non-rejection':'#7f7f7f','Rejection':'#ff7f0e','Pre':'black'}
color_ar = df_meta.assign(Diagnosis=df_meta['Diagnosis_'].map(_diagnosis_palette))

# Normalise and standardise the counts separately for the two cohorts.
#
# The target frame is cast to float up front: the scaled values are floats,
# and writing them into the integer-dtype copy of df_counts relied on silent
# upcasting, which pandas has deprecated (PDEP-6).
df_mrn = df_counts.astype('float64')

# NOTE(review): `hkg` is not imported anywhere in the visible part of this
# file; the import providing `hkg.pp.MRN_transformer` must be added for this
# section to run.
# The last 5 rows are treated as the second (xeno) cohort and everything
# before them as the first. NOTE(review): the 5 is hard-coded -- confirm it
# matches the number of xeno samples loaded.
for cohort_rows in (slice(None, -5), slice(-5, None)):
    # Fresh transformer and scaler per cohort so neither is fitted across cohorts.
    scaler = StandardScaler()
    mrn = hkg.pp.MRN_transformer()
    # Transform the original integer counts, exactly as the int copy was before.
    df_mrn.iloc[cohort_rows, :] = scaler.fit_transform(
        mrn.transform(df_counts.iloc[cohort_rows, :]))

# Samples left out of the heatmap; the same samples are removed from the
# expression matrix and from the row-colour table.
excluded_samples = ['RGc21_S1','RGc22_S2','RGc23_S3']
data = df_mrn.loc[~df_mrn.index.isin(excluded_samples), mkr]
color_ar = color_ar[~color_ar.index.isin(excluded_samples)]

# Hierarchically clustered heatmap of the marker panel with one colour strip
# per row taken from color_ar['Diagnosis']. No z_score argument is passed;
# the colour bar is nevertheless titled 'z-score'.
g = sns.clustermap(
    data=data,
    metric='correlation',
    method='average',
    row_colors=color_ar['Diagnosis'],
    cmap='bwr',
    yticklabels=1,
    cbar_pos=[1.01,0.38,.25,.03],
    cbar_kws={"orientation": "horizontal"},
    figsize=(5,5),
    linewidth=.7,
)
heatmap_ax = g.ax_heatmap
heatmap_ax.set_xlabel('')
g.ax_cbar.set_title('z-score', fontsize=8)
# Dashed horizontal separators at fixed row positions of the clustered heatmap.
for row_pos in (8, 16):
    heatmap_ax.axhline(row_pos, color='black', ls='--')

# Manual legend for the row-colour strip; iterating the dict yields its keys,
# which serve as the legend labels.
lut = {'Non-rejection':'#7f7f7f','Rejection':'#ff7f0e'}
handles = [Patch(facecolor=colour) for colour in lut.values()]
plt.legend(
    handles,
    lut,
    title='Diagnosis',
    bbox_to_anchor=(1.29, .75),
    bbox_transform=plt.gcf().transFigure,
    loc='upper right',
)

# Swap the sample-id tick labels for a shortened form of each sample's 'ar'
# label, with the '-Non-rejection' / 'rejection' parts stripped.
new_label_dict = dict(color_ar['ar'])
new_labels = list(heatmap_ax.get_yticklabels())
for item in new_labels:
    full_label = new_label_dict[item.get_text()]
    item.set_text(full_label.replace('-Non-rejection','').replace('rejection',''))

heatmap_ax.axes.set_yticklabels(new_labels)